import importlib.util
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import statsmodels.api as sm
from linearmodels.iv import IV2SLS


def load_shiftshare_module(root: Path):
    """Import ``analysis/19_iv_shiftshare_rollout_absorb.py`` under *root* as a module.

    The script's file name starts with a digit, so it cannot be reached with a
    normal ``import`` statement; it is loaded directly from its path instead.

    Args:
        root: Project root directory that contains the ``analysis`` folder.

    Returns:
        The executed module object, registered in ``sys.modules`` as ``"shiftshare"``.

    Raises:
        FileNotFoundError: If the script does not exist under *root*.
        ImportError: If no import spec/loader can be created for the path.
    """
    import sys

    path = root / "analysis" / "19_iv_shiftshare_rollout_absorb.py"
    if not path.is_file():
        raise FileNotFoundError(f"Shift-share module not found: {path}")

    spec = importlib.util.spec_from_file_location("shiftshare", path)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {path}")

    mod = importlib.util.module_from_spec(spec)
    # Register before executing (stdlib "import a source file directly" recipe):
    # code in the loaded module, e.g. @dataclass with string annotations, looks
    # its own module up in sys.modules while it is being executed.
    sys.modules[spec.name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module behind on failure.
        sys.modules.pop(spec.name, None)
        raise
    return mod


def residualize_on(y: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Return what is left of *y* after a least-squares projection onto the columns of *X*.

    Works for a 1-D *y* or a 2-D *y* (one residual column per column of *y*).
    """
    coef = np.linalg.lstsq(X, y, rcond=None)[0]
    fitted = X @ coef
    return y - fitted


def zscore(x: np.ndarray) -> np.ndarray:
    """Standardize *x* to mean 0 and (population, ddof=0) SD 1, ignoring NaNs.

    If the SD is zero or not finite, return ``x * 0`` instead: zeros where x is
    finite, with NaNs kept in place.
    """
    arr = np.asarray(x, dtype=float)
    center = float(np.nanmean(arr))
    spread = float(np.nanstd(arr, ddof=0))
    degenerate = (not np.isfinite(spread)) or spread <= 0
    return arr * 0 if degenerate else (arr - center) / spread


def fit_ols_cluster(y: np.ndarray, X: np.ndarray, clusters: np.ndarray):
    """Fit OLS of *y* on *X* plus an intercept, with cluster-robust covariance.

    The intercept is always prepended (``has_constant="add"``), so the first
    parameter is the constant and the slopes on *X* follow in column order.
    *clusters* gives the group label of every observation.

    Returns:
        The statsmodels fitted results object.
    """
    design = sm.add_constant(X, has_constant="add")
    model = sm.OLS(y, design)
    return model.fit(cov_type="cluster", cov_kwds={"groups": clusters})


_LATEX_ESCAPES = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _latex_escape(text: str) -> str:
    """Escape LaTeX special characters in *text* (single pass, so nothing is double-escaped)."""
    return "".join(_LATEX_ESCAPES.get(ch, ch) for ch in text)


def to_latex_tabular(df: pd.DataFrame) -> str:
    """Render the IV chain results as a booktabs LaTeX ``tabular`` environment.

    Args:
        df: One row per outcome/sample with columns ``outcome``, ``sample``,
            ``nobs``, ``b_fs``, ``se_fs``, ``F_fs``, ``b_rf``, ``se_rf``,
            ``b_iv``, ``se_iv``. Other columns are ignored.

    Returns:
        The tabular source, newline-terminated. It uses ``\\toprule`` /
        ``\\midrule`` / ``\\bottomrule``, so the document needs ``booktabs``.

    Formatting:
        - Coefficients and SEs get 4 decimals, the first-stage F gets 2, and
          N is shown as an integer.
        - Missing numbers render as empty cells.
        - Text cells (outcome, sample) are LaTeX-escaped.
    """
    cols = [
        "outcome",
        "sample",
        "nobs",
        "b_fs",
        "se_fs",
        "F_fs",
        "b_rf",
        "se_rf",
        "b_iv",
        "se_iv",
    ]
    # Headers are trusted LaTeX (they contain math mode) and are not escaped.
    header = [
        "Outcome",
        "Sample",
        "N",
        "First stage $b$",
        "se",
        "$F$",
        "Reduced form $b$",
        "se",
        "2SLS $b$",
        "se",
    ]
    d = df[cols].copy()
    # Free-text labels could contain &, %, _ ... which would break the tabular.
    for c in ["outcome", "sample"]:
        d[c] = d[c].map(lambda x: _latex_escape(str(x)))
    d["nobs"] = d["nobs"].map(lambda x: f"{x:.0f}" if pd.notna(x) else "")
    for c in ["b_fs", "se_fs", "b_rf", "se_rf", "b_iv", "se_iv"]:
        d[c] = d[c].map(lambda x: f"{x:.4f}" if pd.notna(x) else "")
    d["F_fs"] = d["F_fs"].map(lambda x: f"{x:.2f}" if pd.notna(x) else "")

    out = [
        "\\begin{tabular}{llrccccccc}",
        "\\toprule",
        " & ".join(header) + " \\\\",
        "\\midrule",
    ]
    for values in d.itertuples(index=False, name=None):
        out.append(" & ".join(str(v) for v in values) + " \\\\")
    out.append("\\bottomrule")
    out.append("\\end{tabular}")
    return "\n".join(out) + "\n"


@dataclass(frozen=True)
class ChainResult:
    """Immutable record of one first-stage / reduced-form / 2SLS chain estimate.

    Instances are produced by ``build_chain``. Field order defines the generated
    ``__init__`` signature, so keep it stable.
    """

    outcome: str  # outcome column name the chain was estimated for
    sample: str  # human-readable sample label
    nobs: int  # observation count reported by the 2SLS fit
    b_fs: float  # first-stage slope of netusoft on the standardized instrument index
    se_fs: float  # cluster-robust SE of b_fs
    F_fs: float  # (b_fs / se_fs) ** 2; NaN when the SE is non-finite or zero
    b_rf: float  # reduced-form slope of the outcome on the instrument index
    se_rf: float  # cluster-robust SE of b_rf
    b_iv: float  # 2SLS estimate; intended to be the coefficient on netusoft
    se_iv: float  # cluster-robust SE of b_iv


def build_chain(mod, df: pd.DataFrame, outcome: str, sample_label: str, controls: list[str], instr_cols: list[str]) -> ChainResult:
    """Estimate the first-stage / reduced-form / 2SLS chain for one outcome and sample.

    Steps:
      1. Drop rows with a missing value in any used column. Absorb ``region``
         and ``ct`` fixed effects with ``mod.absorb_two_way``.
      2. Partial an intercept and *controls* out of the outcome, ``netusoft``
         and the instruments (Frisch-Waugh-Lovell).
      3. Collapse the instruments into one index. The index is the fitted values
         of the residualized first stage, standardized to unit SD.
      4. Regress ``netusoft`` (first stage) and the outcome (reduced form) on
         the index.
      5. Run just-identified 2SLS of the outcome on ``netusoft``, with *controls*
         as exogenous regressors and the index as the excluded instrument.

    All standard errors are clustered by ``region``. The index is orthogonal to
    the intercept and controls, so the 2SLS point estimate equals
    ``b_rf / b_fs`` algebraically. Note that the index is built from the same
    data, so ``F_fs`` is an optimistic strength measure.

    Args:
        mod: Loaded shift-share module providing ``absorb_two_way``.
        df: Respondent-level data.
        outcome: Outcome column name.
        sample_label: Label stored in the result and used in error messages.
        controls: Exogenous control columns (may be empty).
        instr_cols: Instrument columns (at least one).

    Returns:
        A ``ChainResult`` for this outcome/sample.

    Raises:
        ValueError: If the estimation sample is empty, or if absorption changes
            the row count (region clusters could then not be aligned).
    """
    cols = [outcome, "netusoft", "ct", "region", *controls, *instr_cols]
    d = df[cols].dropna().copy()
    if d.empty:
        raise ValueError(f"Empty estimation sample for outcome={outcome} sample={sample_label}")

    # Absorb region FE and `ct` FE as in the main pipeline
    # (ct presumably indexes country×year cells; confirm).
    model_cols = [outcome, "netusoft", *controls, *instr_cols]
    r = mod.absorb_two_way(d, model_cols, fe_a="region", fe_b="ct")
    if len(r) != len(d):
        raise ValueError(
            f"absorb_two_way returned {len(r)} rows for {len(d)} inputs "
            f"(outcome={outcome} sample={sample_label}); cannot align region clusters"
        )
    # NOTE(review): clusters are read positionally from `d`. This assumes
    # absorb_two_way preserves row order; confirm.
    clusters = d["region"].to_numpy()

    y = r[outcome].to_numpy(dtype=float)
    d_endog = r["netusoft"].to_numpy(dtype=float)
    X = r[controls].to_numpy(dtype=float) if controls else np.zeros((len(r), 0), dtype=float)
    Z = r[instr_cols].to_numpy(dtype=float)

    # Partial out an intercept plus controls (FWL). The intercept is always
    # included, so the instrument index is built the same way with or without
    # controls. residualize_on handles the 2-D instrument matrix column-wise.
    Xc = np.column_stack([np.ones(len(r)), X])
    y_tilde = residualize_on(y, Xc)
    d_tilde = residualize_on(d_endog, Xc)
    Z_tilde = residualize_on(Z, Xc)

    # Build a scalar "instrument index" as the fitted values from the (excluded) first stage,
    # standardized to 1 SD. This collapses the multi-instrument set to a single, auditable index.
    pi_hat, *_ = np.linalg.lstsq(Z_tilde, d_tilde, rcond=None)
    z_index = zscore(Z_tilde @ pi_hat)

    # First stage: d_tilde on z_index. The single-instrument cluster-robust
    # Wald statistic (t^2) is reported as F.
    fs = fit_ols_cluster(d_tilde, z_index.reshape(-1, 1), clusters=clusters)
    b_fs = float(fs.params[1])
    se_fs = float(fs.bse[1])
    F_fs = float((b_fs / se_fs) ** 2) if np.isfinite(b_fs) and np.isfinite(se_fs) and se_fs > 0 else float("nan")

    # Reduced form: y_tilde on z_index
    rf = fit_ols_cluster(y_tilde, z_index.reshape(-1, 1), clusters=clusters)
    b_rf = float(rf.params[1])
    se_rf = float(rf.bse[1])

    # 2SLS: y on netusoft, instrumented by z_index, with controls as exog.
    # linearmodels orders parameters exog-first, then endog. Positional access
    # (`params.iloc[0]`) therefore picks up the first *control* whenever
    # controls are present. Name every input and read the coefficient by label.
    dep = pd.Series(y, name=outcome)
    endog = pd.Series(d_endog, name="netusoft")
    instruments = pd.Series(z_index, name="z_index")
    exog = pd.DataFrame(X, columns=list(controls)) if controls else None
    iv = IV2SLS(dep, exog=exog, endog=endog, instruments=instruments).fit(cov_type="clustered", clusters=clusters)
    b_iv = float(iv.params["netusoft"])
    se_iv = float(iv.std_errors["netusoft"])

    return ChainResult(
        outcome=outcome,
        sample=sample_label,
        nobs=int(iv.nobs),
        b_fs=b_fs,
        se_fs=se_fs,
        F_fs=F_fs,
        b_rf=b_rf,
        se_rf=se_rf,
        b_iv=b_iv,
        se_iv=se_iv,
    )


def main() -> None:
    """Estimate the IV chain for every outcome × education sample and write the tables.

    Paths are relative to the project root (the parent of this script's folder):
    the CSV goes to ``outputs/`` and the LaTeX tabular to ``paper_joc/tables/``.
    Rows whose ``edu_high`` is neither 0 nor 1 (e.g. missing) are kept in
    "All respondents" but fall into neither education subsample.

    Raises:
        RuntimeError: If none of the module's ``INSTRUMENT_BASE_COLS`` is in the data.
    """
    project_root = Path(__file__).resolve().parents[1]
    csv_dir = project_root / "outputs"
    tex_dir = project_root / "paper_joc" / "tables"
    for target in (csv_dir, tex_dir):
        target.mkdir(parents=True, exist_ok=True)

    shiftshare = load_shiftshare_module(project_root)
    data = shiftshare.build_dataset(project_root)

    candidate_instruments = getattr(shiftshare, "INSTRUMENT_BASE_COLS", [])
    instruments = [name for name in candidate_instruments if name in data.columns]
    if len(instruments) == 0:
        raise RuntimeError("No instrument columns found in dataset.")

    outcome_labels = (
        ("y_part_index", "Participation index"),
        ("y_vote", "Vote (yes/no)"),
    )
    # Must stay a list: build_chain indexes the frame with it.
    control_cols = ["agea", "gndr", "hinctnta"]

    high_edu = data["edu_high"]
    samples = (
        ("All respondents", data),
        ("Low education", data[high_edu == 0].copy()),
        ("High education", data[high_edu == 1].copy()),
    )

    records = []
    for var, label in outcome_labels:
        for sample_label, subset in samples:
            chain = build_chain(
                shiftshare,
                subset,
                outcome=var,
                sample_label=sample_label,
                controls=control_cols,
                instr_cols=instruments,
            )
            # The table shows the human-readable outcome label, not the column name.
            records.append(
                dict(
                    outcome=label,
                    sample=chain.sample,
                    nobs=chain.nobs,
                    b_fs=chain.b_fs,
                    se_fs=chain.se_fs,
                    F_fs=chain.F_fs,
                    b_rf=chain.b_rf,
                    se_rf=chain.se_rf,
                    b_iv=chain.b_iv,
                    se_iv=chain.se_iv,
                )
            )

    table = pd.DataFrame(records)
    csv_path = csv_dir / "iv_chain_firststage_reducedform_2sls.csv"
    tex_path = tex_dir / "iv_chain_firststage_reducedform_2sls.tex"
    csv_path.write_text(table.to_csv(index=False), encoding="utf-8")
    tex_path.write_text(to_latex_tabular(table), encoding="utf-8")
    print(f"Wrote: {tex_path}")


if __name__ == "__main__":
    # Run the estimation pipeline only when executed as a script, not on import.
    main()
